{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Signal echoing\n",
    "\n",
    "Echoing a signal `n` steps later is an example of a synchronized many-to-many task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from res.sequential_tasks import EchoData\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "# import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "torch.manual_seed(1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by the training and the test data.\n",
    "batch_size = 5\n",
    "echo_step = 3\n",
    "series_length = 20_000\n",
    "BPTT_T = 20  # passed to EchoData as `truncated_length`\n",
    "\n",
    "echo_kwargs = dict(\n",
    "    echo_step=echo_step,\n",
    "    batch_size=batch_size,\n",
    "    series_length=series_length,\n",
    "    truncated_length=BPTT_T,\n",
    ")\n",
    "\n",
    "# Two separate EchoData instances built from identical settings.\n",
    "train_data = EchoData(**echo_kwargs)\n",
    "train_size = len(train_data)\n",
    "\n",
    "test_data = EchoData(**echo_kwargs)\n",
    "test_size = len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first 20 time steps of the first input sequence and of its target:\n",
    "first_x = train_data.x_batch[0, :20]\n",
    "first_y = train_data.y_batch[0, :20]\n",
    "\n",
    "print('(1st input sequence)  x:', *first_x, '... ')\n",
    "print('(1st target sequence) y:', *first_y, '... ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch_size different sequences are created; show the first 20 steps of each.\n",
    "# str(row)[1:-1] drops the first and last character of each row's string form.\n",
    "x_rows = [str(row)[1:-1] + ' ...' for row in train_data.x_batch[:, :20]]\n",
    "y_rows = [str(row)[1:-1] + ' ...' for row in train_data.y_batch[:, :20]]\n",
    "\n",
    "print('x_batch:', *x_rows, sep='\\n')\n",
    "print('x_batch size:', train_data.x_batch.shape)\n",
    "print()\n",
    "print('y_batch:', *y_rows, sep='\\n')\n",
    "print('y_batch size:', train_data.y_batch.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# RNNs consume the data as temporal chunks of size\n",
    "# [batch_size, T, feature_dim]; look at the first chunk of each.\n",
    "first_x_chunk = train_data.x_chunks[0]\n",
    "first_y_chunk = train_data.y_chunks[0]\n",
    "\n",
    "print('x_chunk:', *first_x_chunk.squeeze(), sep='\\n')\n",
    "print('1st x_chunk size:', first_x_chunk.shape)\n",
    "print()\n",
    "print('y_chunk:', *first_y_chunk.squeeze(), sep='\\n')\n",
    "print('1st y_chunk size:', first_y_chunk.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleRNN(nn.Module):\n",
    "    \"\"\"Single-layer ReLU RNN followed by a linear read-out applied at every time step.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of features per time step of the input.\n",
    "        rnn_hidden_size: number of hidden units of the recurrent layer.\n",
    "        output_size: number of values produced per time step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, rnn_hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        self.rnn_hidden_size = rnn_hidden_size\n",
    "        self.rnn = torch.nn.RNN(\n",
    "            input_size=input_size,\n",
    "            hidden_size=rnn_hidden_size,\n",
    "            num_layers=1,\n",
    "            nonlinearity='relu',\n",
    "            batch_first=True\n",
    "        )\n",
    "        # Size the read-out from `output_size` rather than a hard-coded 1,\n",
    "        # so the constructor argument is actually honoured.\n",
    "        self.linear = torch.nn.Linear(\n",
    "            in_features=rnn_hidden_size,\n",
    "            out_features=output_size\n",
    "        )\n",
    "\n",
    "    def forward(self, x, hidden):\n",
    "        \"\"\"Run the RNN over `x` and project every step's hidden state to the output.\n",
    "\n",
    "        Args:\n",
    "            x: input of shape [batch, time, input_size] (the RNN is batch_first).\n",
    "            hidden: initial hidden state, or None to let the RNN start from zeros.\n",
    "\n",
    "        Returns:\n",
    "            (output, hidden): per-step outputs of shape\n",
    "            [batch, time, output_size] and the final hidden state.\n",
    "        \"\"\"\n",
    "        x, hidden = self.rnn(x, hidden)\n",
    "        x = self.linear(x)\n",
    "        return x, hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(hidden):\n",
    "    \"\"\"Run one pass over `train_data`, stepping the optimizer after every chunk.\n",
    "\n",
    "    Relies on the notebook globals `model`, `optimizer`, `criterion`,\n",
    "    `train_data`, `train_size` and `device`.\n",
    "\n",
    "    Args:\n",
    "        hidden: hidden state to start from, or None.\n",
    "\n",
    "    Returns:\n",
    "        (correct, loss, hidden): number of correctly predicted values over\n",
    "        the whole pass, the loss of the last chunk only, and the hidden\n",
    "        state after the last chunk.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "\n",
    "    correct = 0\n",
    "    for chunk_idx in range(train_size):\n",
    "        x_np, y_np = train_data[chunk_idx]\n",
    "        x = torch.from_numpy(x_np).float().to(device)\n",
    "        y = torch.from_numpy(y_np).float().to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # Cut the autograd graph in place so gradients stop at the chunk boundary.\n",
    "        if hidden is not None:\n",
    "            hidden.detach_()\n",
    "        logits, hidden = model(x, hidden)\n",
    "        loss = criterion(logits, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Threshold the sigmoid at 0.5 and count the values that match the target.\n",
    "        pred = torch.sigmoid(logits) > 0.5\n",
    "        n_hits = pred.eq(y.byte()).int().sum().item()\n",
    "        correct += n_hits\n",
    "\n",
    "    return correct, loss.item(), hidden"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(hidden):\n",
    "    \"\"\"Run `model` over `test_data` without gradients and count correct values.\n",
    "\n",
    "    Relies on the notebook globals `model`, `test_data`, `test_size` and\n",
    "    `device`.\n",
    "\n",
    "    Args:\n",
    "        hidden: hidden state to start from, or None; it is carried from\n",
    "            chunk to chunk but not returned.\n",
    "\n",
    "    Returns:\n",
    "        Number of correctly predicted values over all test chunks.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for chunk_idx in range(test_size):\n",
    "            x_np, y_np = test_data[chunk_idx]\n",
    "            x = torch.from_numpy(x_np).float().to(device)\n",
    "            y = torch.from_numpy(y_np).float().to(device)\n",
    "            logits, hidden = model(x, hidden)\n",
    "\n",
    "            # Same decision rule as in training: sigmoid thresholded at 0.5.\n",
    "            pred = torch.sigmoid(logits) > 0.5\n",
    "            n_hits = pred.eq(y.byte()).int().sum().item()\n",
    "            correct += n_hits\n",
    "\n",
    "    return correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_dim = 1  # since we have a scalar series\n",
    "h_units = 4\n",
    "\n",
    "# The echo task maps the series onto itself, so input and output share\n",
    "# `feature_dim`; deriving both from it keeps them from drifting apart.\n",
    "model = SimpleRNN(\n",
    "    input_size=feature_dim,\n",
    "    rnn_hidden_size=h_units,\n",
    "    output_size=feature_dim\n",
    ").to(device)\n",
    "# No hidden state yet: the first training call starts from None.\n",
    "hidden = None\n",
    "\n",
    "# The model emits raw logits, which is what BCEWithLogitsLoss takes as input.\n",
    "criterion = torch.nn.BCEWithLogitsLoss()\n",
    "optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_epochs = 5\n",
    "\n",
    "# `correct` counts individual values, while `train_size` / `test_size` count\n",
    "# chunks. Normalise by the number of values per chunk (assumes every chunk is\n",
    "# [batch_size, BPTT_T, 1]) and scale by 100, so the printed figure is a real\n",
    "# percentage for any batch_size / BPTT_T, not only when their product is 100.\n",
    "values_per_chunk = batch_size * BPTT_T\n",
    "\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    correct, loss, hidden = train(hidden)\n",
    "    train_accuracy = 100 * correct / (train_size * values_per_chunk)\n",
    "    print(f'Train Epoch: {epoch}/{n_epochs}, loss: {loss:.3f}, accuracy {train_accuracy:.1f}%')\n",
    "\n",
    "# Evaluate on the test data, carrying over the hidden state left by training.\n",
    "correct = test(hidden)\n",
    "test_accuracy = 100 * correct / (test_size * values_per_chunk)\n",
    "print(f'Test accuracy: {test_accuracy:.1f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's try some echoing\n",
    "# One random 0/1 sequence of 100 steps with a single feature per step.\n",
    "my_input = torch.empty(1, 100, 1).random_(2).to(device)\n",
    "# Start from no hidden state.\n",
    "hidden = None\n",
    "my_out, _ = model(my_input, hidden)\n",
    "\n",
    "# Flatten both to one row; a positive logit is shown as a predicted 1.\n",
    "input_bits = my_input.view(1, -1).byte()\n",
    "predicted_bits = (my_out > 0).view(1, -1)\n",
    "print(input_bits, predicted_bits, sep='\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
